{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Variational Autoencoder with Normalizing Flows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import packages\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "from tqdm import tqdm\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import normflows as nf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "torch.manual_seed(0)\n",
    "\n",
    "batch_size = 64\n",
    "num_samples = 32\n",
    "n_flows = 40\n",
    "n_bottleneck = 40\n",
    "hidden_units_encoder = np.array([28 ** 2, 512, 256, n_bottleneck * 2])\n",
    "hidden_units_decoder = np.array([n_bottleneck, 256, 512, 28 ** 2])\n",
    "flow_type = 'Planar'\n",
    "n_epochs = 15\n",
    "\n",
    "enable_cuda = True\n",
    "device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get dataloader\n",
    "class BinaryTransform():\n",
    "    def __init__(self, thresh=0.5):\n",
    "        self.thresh = thresh\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return (x > self.thresh).type(x.type())\n",
    "transform=transforms.Compose([transforms.ToTensor(), BinaryTransform()])\n",
    "mnist_train = datasets.MNIST('../datasets', train=True, download=True, transform=transform)\n",
    "mnist_test = datasets.MNIST('../datasets', train=False, download=True, transform=transform)\n",
    "train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up model\n",
    "prior = torch.distributions.MultivariateNormal(torch.zeros(n_bottleneck, device=device),\n",
    "                                               torch.eye(n_bottleneck, device=device))\n",
    "encoder_nn = nf.nets.MLP(hidden_units_encoder)\n",
    "encoder = nf.distributions.NNDiagGaussian(encoder_nn)\n",
    "decoder_nn = nf.nets.MLP(hidden_units_decoder)\n",
    "decoder = nf.distributions.NNBernoulliDecoder(decoder_nn)\n",
    "\n",
    "if flow_type == 'Planar':\n",
    "    flows = [nf.flows.Planar((n_bottleneck,)) for k in range(n_flows)]\n",
    "elif flow_type == 'Radial':\n",
    "    flows = [nf.flows.Radial((n_bottleneck,)) for k in range(n_flows)]\n",
    "elif flow_type == 'RealNVP':\n",
    "    b = torch.tensor(n_bottleneck // 2 * [0, 1] + n_bottleneck % 2 * [0])\n",
    "    flows = []\n",
    "    for i in range(n_flows):\n",
    "        s = nf.nets.MLP([n_bottleneck, n_bottleneck])\n",
    "        t = nf.nets.MLP([n_bottleneck, n_bottleneck])\n",
    "        if i % 2 == 0:\n",
    "            flows += [nf.flows.MaskedAffineFlow(b, t, s)]\n",
    "        else:\n",
    "            flows += [nf.flows.MaskedAffineFlow(1 - b, t, s)]\n",
    "else:\n",
    "    raise NotImplementedError\n",
    "\n",
    "nfm = nf.NormalizingFlowVAE(prior, encoder, flows, decoder)\n",
    "nfm.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Train model\n",
    "log_intv = 100\n",
    "optimizer = optim.Adam(nfm.parameters(), lr=1e-4, weight_decay=1e-4)\n",
    "for epoch in range(n_epochs):\n",
    "    progressbar = tqdm(enumerate(train_loader), total=len(train_loader))\n",
    "    for batch_n, (x, n) in progressbar:\n",
    "        x = x.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        z, log_q, log_p = nfm(x.view(x.size(0) * x.size(1), 28 ** 2), num_samples)\n",
    "        mean_log_q = torch.mean(log_q)\n",
    "        mean_log_p = torch.mean(log_p)\n",
    "        loss = mean_log_q - mean_log_p\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        progressbar.update()\n",
    "    progressbar.close()\n",
    "    print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_n * len(x), len(train_loader.dataset),\n",
    "                       100. * batch_n / len(train_loader),\n",
    "                       loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_out = nfm.decoder(torch.randn((1, n_bottleneck), device=device))\n",
    "x_np = x_out.view((28, 28)).to('cpu').detach().numpy()\n",
    "plt.imshow(x_np)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
